from django.db import transaction
from django.db import connection
from django.core.management.base import CommandParser
from django.core.management.commands import migrate

from ...models import Tenant
from ...utils.db import schema_exists, create_schema


class Command(migrate.Command):
    """Schema-aware variant of Django's ``migrate`` command.

    Runs the stock ``migrate`` once for the ``public`` schema and, unless
    ``--public yes`` is given, once more for every schema name stored on
    ``Tenant``.
    """

    def _apply_migrate(self, only_public: bool, *args, **options) -> None:
        """Run the stock ``migrate`` once per target schema.

        Args:
            only_public: When true, only the ``public`` schema is migrated;
                otherwise every ``Tenant.schema_name`` is migrated as well.
            *args: Positional arguments forwarded to ``migrate``'s ``handle``.
            **options: Parsed command options forwarded to ``migrate``'s
                ``handle``.
        """
        schemas = ['public']

        if not only_public:
            # Migrate every tenant schema in the same run.
            schemas.extend(Tenant.objects.values_list('schema_name', flat=True))

        # Drop duplicates while keeping order, so a Tenant row that is itself
        # named 'public' does not trigger a second pass over that schema.
        schemas = list(dict.fromkeys(schemas))

        for schema in schemas:
            # Write through self.stdout (not print) so output honours the
            # stream handed to the command, e.g. call_command(stdout=...).
            self.stdout.write(f'> 正在同步 {schema}')

            # Make sure the schema exists before migrating into it.
            if not schema_exists(schema):
                self.stdout.write(f'{schema}不存在，创建中')
                create_schema(schema)

            # Point the connection at this schema. include_public is True for
            # every schema except 'public' itself.
            # NOTE(review): set_schema comes from the project's DB backend;
            # presumably it adjusts the search path - confirm there.
            connection.set_schema(schema, include_public=schema != 'public')

            # Run Django's regular migrate against the selected schema.
            super().handle(*args, **options)

            # Reset the database connection before moving to the next schema.
            # Best effort: a TransactionManagementError (raised by Django when
            # commit() is called inside an atomic block) skips the reset and
            # leaves the connection untouched.
            try:
                transaction.commit()
                connection.close()
                connection.connection = None
            except transaction.TransactionManagementError:
                pass

    def add_arguments(self, parser: CommandParser) -> None:
        """Register ``migrate``'s own arguments plus ``--public {yes,no}``.

        ``--public`` defaults to ``no``.
        """
        super().add_arguments(parser)
        parser.add_argument('--public', choices=['yes', 'no'], default='no')

    def handle(self, *args, **options) -> None:
        """Entry point: read ``--public`` and migrate the selected schemas.

        The option value is lower-cased and stripped before comparison, so
        values such as ``' YES '`` passed programmatically are also accepted.
        """
        only_public = options['public'].lower().strip() == 'yes'
        self._apply_migrate(only_public, *args, **options)
